{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize some results.\n",
    "\n",
    "Sometimes, its useful to visualize your model's results to see what it gets right and what it gets wrong. This notebook guides you through iterativing through the dataset and visualizing some results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iterator import SmartIterator\n",
    "from utils.visualization_utils import get_att_map, objdict, get_dict, add_attention, add_bboxes, get_bbox_from_heatmap, add_bbox_to_image\n",
    "from keras.models import load_model\n",
    "from models import ReferringRelationshipsModel\n",
    "from keras.utils import to_categorical\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "import json\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import h5py\n",
    "from keras.models import Model\n",
    "import seaborn as sns\n",
    "from scipy.misc import imresize\n",
    "from urllib.request import urlopen\n",
    "from io import BytesIO\n",
    "from keras.applications.resnet50 import preprocess_input\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 34})\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the dataset you want to visualize.\n",
    "\n",
    "Note that you will have to point the `img_dir` variable to where you saved the images for that dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###################\n",
    "data_type = \"clevr\"\n",
    "###################\n",
    "if data_type==\"vrd\":\n",
    "    annotations_file = \"data/VRD/annotations_test.json\"\n",
    "    img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/' # You will have to change this to where your images are stored.\n",
    "    vocab_dir = os.path.join('data/VRD')\n",
    "    model_checkpoint = \"pretrained/vrd.h5\"\n",
    "elif data_type==\"clevr\":\n",
    "    annotations_file = \"data/clevr/annotations_test.json\"\n",
    "    img_dir = '/data/ranjaykrishna/clevr/images/test'  # You will have to change this to where your images are stored.\n",
    "    vocab_dir = os.path.join('data/clevr')\n",
    "    model_checkpoint = \"pretrained/clevr.h5\"\n",
    "elif data_type==\"visualgenome\":\n",
    "    annotations_file = \"data/VisualGenome/annotations_test.json\"\n",
    "    img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'  # You will have to change this to where your images are stored.\n",
    "    vocab_dir = os.path.join('data/VisualGenome')\n",
    "    model_checkpoint = \"pretrained/visualgenome.h5\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iou(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    y_pred = y_pred > thresh\n",
    "    intersection = (y_pred * y_true).sum(axis=1)\n",
    "    union = eps + ((y_pred + y_true)>0).sum(axis=1)\n",
    "    return list(intersection/union)\n",
    "\n",
    "def recall(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    y_pred = y_pred > thresh\n",
    "    tp = (y_pred * y_true).sum(axis=1)\n",
    "    fn = (1*((y_true - y_pred)>0)).sum(axis=1)\n",
    "    recall = tp/(tp+fn+eps)\n",
    "    return list(recall)\n",
    "\n",
    "def precision(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    y_pred = y_pred > thresh\n",
    "    tp = (y_pred * y_true).sum(axis=1)\n",
    "    p = y_pred.sum(axis=1)\n",
    "    prec = tp/(p+eps)\n",
    "    return list(prec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "annotations_test = json.load(open(annotations_file))\n",
    "predicate_dict, obj_subj_dict = get_dict(vocab_dir)\n",
    "image_ids = sorted(list(annotations_test.keys()))\n",
    "params = objdict(json.load(open(os.path.join(os.path.dirname(model_checkpoint), \"args.json\"), \"r\")))\n",
    "params.norm_scale=1\n",
    "params.use_internal_loss = False\n",
    "params.cnn = 'resnet'\n",
    "params.batch_size = 1\n",
    "params.discovery = False\n",
    "relationships_model = ReferringRelationshipsModel(params)\n",
    "test_generator = SmartIterator(params.test_data_dir, params)\n",
    "print(' | '.join(obj_subj_dict))\n",
    "print('')\n",
    "print(' | '.join(predicate_dict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = relationships_model.build_model()\n",
    "model.load_weights(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize some results.\n",
    "\n",
    "The `rel_range` indicates how many and starting from what index you want to visualize results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################\n",
    "rel_range = [200, 300]\n",
    "#################\n",
    "metrics = [iou, recall, precision]\n",
    "metrics = [lambda x, y: precision(x, y, thresh=0.8)]\n",
    "\n",
    "for rel_idx in range(rel_range[0], rel_range[1]):\n",
    "    inputs, outputs = test_generator[rel_idx]\n",
    "    s_pred, o_pred = model.predict(inputs)\n",
    "    \n",
    "    # Evaluate\n",
    "    results = {}\n",
    "    for metric in metrics:\n",
    "        results['s_' + metric.__name__] = metric(outputs[0], s_pred)\n",
    "        results['o_' + metric.__name__] = metric(outputs[1], o_pred)\n",
    "    \n",
    "    # visualize\n",
    "    relationship = [obj_subj_dict[int(inputs[1][0,0])], \n",
    "                    predicate_dict[int(inputs[2][0,0])], \n",
    "                    obj_subj_dict[int(inputs[3][0,0])]] \n",
    "    \n",
    "    #image_index = #TODO\n",
    "    #img = Image.open(os.path.join(img_dir, image_ids[image_index]))\n",
    "    #img = img.resize((params.input_dim, params.input_dim))\n",
    "    img = inputs[0][0] + np.array([103.939, 116.779, 123.68])\n",
    "    img = Image.fromarray(img.astype('uint8'), 'RGB')\n",
    "    att_map = get_att_map(img, np.maximum(s_pred[0],0), o_pred[0], params.input_dim, relationship)\n",
    "    plt.figure(figsize=(15, 15))\n",
    "    plt.imshow(att_map)\n",
    "    plt.axis(\"off\")\n",
    "    plt.show()\n",
    "    print(relationship)\n",
    "    print(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
